# Environment setup for the Visium / RCTD analysis.
# All relative paths used further down (cached .rds files, plot directory) are
# resolved against this working directory.
setwd("/data/home/lyang/Visium_RCTD")

# Load a site-specific HDF5 high-level shared library before any package is
# attached (presumably needed by the HDF5-backed readers used below -- TODO confirm).
dyn.load("/data/home/bioinfo/programs/hdf5-1.12.0/hdf5/lib/libhdf5_hl.so.200")
# Library search path used by the library() calls below.
.libPaths(c("/data/home/lyang/R/x86_64-redhat-linux-gnu-library/4.1",
            "/data/home/bioinfo/R/R4.1.0/library_B3.13",
            "/usr/lib64/R/library","/usr/share/R/library") )
# Point knitr's root directory at the same project directory.
knitr::opts_knit$set(root.dir = "/data/home/lyang/Visium_RCTD")
# NOTE(review): same directory as the first setwd() above; looks redundant.
setwd("/data/home/lyang/Visium_RCTD")
# Attach every package the analysis depends on, in a fixed order.
# Takes no arguments; like a trailing library() call it invisibly returns the
# character vector of attached packages.
init_require_packages <- function(){
  required_packages <- c(
    "Matrix", "spacexr", "Seurat", "readxl", "SeuratData", "stringr",
    "patchwork", "dplyr", "ggplot2", "SeuratDisk", "doParallel"
  )
  for (pkg in required_packages) {
    library(pkg, character.only = TRUE)
  }
  # library() itself returns invisible(.packages()); keep that return value.
  invisible(.packages())
}

init_require_packages()
## Registered S3 method overwritten by 'spatstat.geom':
##   method     from
##   print.boxx cli
## Attaching SeuratObject
## ── Installed datasets ───────────────────────────────────── SeuratData v0.2.1 ──
## ✓ ifnb      3.1.0                       ✓ stxKidney 0.1.0
## ✓ ssHippo   3.1.4
## ────────────────────────────────────── Key ─────────────────────────────────────
## ✓ Dataset loaded successfully
## > Dataset built with a newer version of Seurat than installed
## ❓ Unknown version of Seurat installed
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
## Registered S3 method overwritten by 'SeuratDisk':
##   method            from  
##   as.sparse.H5Group Seurat
## Loading required package: foreach
## Loading required package: iterators
## Loading required package: parallel
# Load the single-cell reference dataset, using an .rds cache named after
# `dataset_name` when it already exists in the working directory.
dataset_name = "cell2location"
# dataset_name = "Tabula" "cell2location"
# Cache file name, e.g. "cell2location_sc_dataset.rds".
sc_dataset_file = paste(dataset_name,"sc_dataset.rds",sep ="_")

if(file.exists(sc_dataset_file)){
  # Cache hit: read the previously saved object.
  rds_file= sc_dataset_file
  sc_dataset <- readRDS(rds_file)
  }else{
  
  # Cache miss: import the RNA assay from the h5seurat file, save it to
  # `sc_dataset_file` (read from the enclosing environment) and return it.
  import_sc_dataset <- function()
  {
    # Convert("D:/minor_intership_data/sc.h5ad", dest = "h5seurat", overwrite = TRUE)
    # NOTE(review): absolute Windows path, while the working directory above is
    # a Linux path -- confirm where the source file actually lives.
    sc_dataset <- LoadH5Seurat("D:/minor_intership_data/sc.h5seurat",assays  = "RNA")
    # Bare expression inside a function body: evaluated but not printed.
    sc_dataset
    
    saveRDS(sc_dataset,sc_dataset_file)
    return(sc_dataset)
  }
  sc_dataset = import_sc_dataset()
  # Evaluated for inspection of the object and its metadata column names.
  sc_dataset
  colnames(x = sc_dataset[[]])
    
}
# QC-filter a single-cell Seurat object.
#
# sc_dataset:   Seurat object with an RNA assay and an `n_genes` metadata column.
# dataset_name: "cell2location" or "Tabula"; selects which metadata column is
#               copied into `nCount_RNA`. Any other value leaves the metadata
#               untouched at that step.
#
# Returns a list with `sc_dataset` (input plus the added metadata) and
# `sc_dataset_filtered` (cells passing the thresholds).
qc_filter_sc <- function(sc_dataset,dataset_name){
  # Copy the dataset-specific UMI count column into `nCount_RNA`.
  if (dataset_name == "cell2location") {
    sc_dataset[["nCount_RNA"]] <- sc_dataset$n_counts
  } else if (dataset_name == "Tabula") {
    sc_dataset[["nCount_RNA"]] <- sc_dataset$n_counts_UMIs
  }

  # Mitochondrial share per cell, computed as a fraction (0-1, not a
  # percentage): column sums over the "MT-" genes divided by column sums over
  # all genes of the RNA assay.
  rna_assay <- sc_dataset@assays$RNA
  mito_genes <- grep(pattern = "^MT-", x = rownames(x = rna_assay), value = TRUE)
  mito_fraction <- Matrix::colSums(rna_assay[mito_genes, ]) /
    Matrix::colSums(rna_assay)
  sc_dataset <- AddMetaData(
    object = sc_dataset,
    metadata = mito_fraction,
    col.name = "percent.mt"
  )

  # Keep cells with 200 < n_genes < 5000 and a mitochondrial fraction below 0.05.
  cells_passing_qc <- subset(
    sc_dataset,
    subset = n_genes > 200 & n_genes < 5000 & percent.mt < 0.05
  )
  list("sc_dataset" = sc_dataset, "sc_dataset_filtered" = cells_passing_qc)
}
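# Cells with more than 5,000 or fewer than 200 detected genes are filtered out.
# Cells with more than 5% mitochondrial counts (fraction >= 0.05) are filtered out.
# These QC thresholds are adapted from the Seurat tutorial. An earlier note reported
# that only 9532/53275 cells were retained; the output recorded below shows 72566 cells.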
# Run the QC filter and unpack both results: `sc_dataset` now carries the added
# metadata columns, `sc_dataset_filtered` holds the cells that passed QC.
a_list = qc_filter_sc(sc_dataset,dataset_name)
sc_dataset = a_list$sc_dataset
sc_dataset_filtered = a_list$sc_dataset_filtered

# Print a summary of the filtered object (recorded output follows).
sc_dataset_filtered
## An object of class Seurat 
## 10237 features across 72566 samples within 1 assay 
## Active assay: RNA (10237 features, 0 variable features)
##  2 dimensional reductions calculated: pca, umap
# sc_dataset_filtered$CellType
# # Visualize QC metrics as a violin plot
# Note: plotted on the unfiltered object, i.e. the distributions before QC.
VlnPlot(sc_dataset, features = c("n_genes", "nCount_RNA"), ncol = 2)

# Scrutinize which cell types are lost by QC filtering: compare the cell type
# composition (as proportions) before and after filtering.

# dataset_name = "cell2location34" "Tabula"
dataset_name = "cell2location34"

# Metadata column holding the cell type annotation for the chosen dataset.
cell_type_column <- switch(
  dataset_name,
  "cell2location" = "CellType",
  "Tabula" = "cell_ontology_class",
  "cell2location34" = "Subset",
  stop("Unsupported dataset_name: ", dataset_name, call. = FALSE)
)

counts_before <- table(sc_dataset@meta.data[[cell_type_column]])
counts_after <- table(sc_dataset_filtered@meta.data[[cell_type_column]])

# Align the post-filter counts to the pre-filter cell types by name. A cell
# type that disappears entirely after filtering gets a count of 0 instead of
# shifting the remaining counts onto the wrong cell types.
type_names <- names(counts_before)
after_aligned <- as.vector(counts_after)[match(type_names, names(counts_after))]
after_aligned[is.na(after_aligned)] <- 0

# One row per cell type; both columns are proportions summing to 1.
cell_type_table <- data.frame(
  cell_type = factor(type_names, levels = type_names),
  before = as.vector(counts_before) / sum(counts_before),
  after = after_aligned / sum(after_aligned)
)

# Long format: one row per (cell type, state) pair for plotting.
library(reshape2)
cell_type_table <- melt(cell_type_table, id.vars = "cell_type",
                        variable.name = "state",  value.name = "proportion")

ggplot(data = cell_type_table, mapping = aes(x = cell_type, 
                                             y = proportion, 
                                             colour = state,group = state )) + 
geom_point() +geom_line(aes(color=state))+
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5))

Quick data exploration:

# Bar plot of the number of cells per annotated cell type (unfiltered object).
# annotation = "cell2location34" Tabula
# annotation = "cell2location34"
dataset_name = "cell2location34"

# Metadata column holding the cell type annotation. The rest of the script
# names the 44-type annotation "cell2location", so that key is accepted here
# alongside the original "cell2location44".
cell_type_column <- switch(
  dataset_name,
  "cell2location34" = "Subset",
  "Tabula" = "cell_ontology_class",
  "cell2location" = ,
  "cell2location44" = "CellType",
  stop("Unsupported dataset_name: ", dataset_name, call. = FALSE)
)

cell_type_table <- as.data.frame(table(sc_dataset@meta.data[[cell_type_column]]))
colnames(cell_type_table) = c("cell_type","frequency")

library(ggplot2)
# Basic barplot, with the count printed above each bar.
p <- ggplot(data=cell_type_table, aes(x=cell_type, y=frequency,fill=cell_type)) +
  geom_bar(stat="identity")+
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5))+
       geom_text(aes(label=frequency), position=position_dodge(width=0.9), vjust=-0.25)


# ggsave(p,filename = "cell_type_histogram.pdf",width = 15,
#        height = 20)

print(p)

# Remove every cell that belongs to a rare cell type (25 cells or fewer).
#
# sc_dataset_filtered: Seurat object. The annotation column is chosen from the
#                      global `dataset_name`; for an unrecognised name the
#                      identities are left as they are and no cell is removed.
#
# Returns the Seurat object without the cells of rare cell types.
delete_rare_cells <- function(sc_dataset_filtered)
{
  # Use the dataset-specific annotation column as the cell identity.
  if (dataset_name == "cell2location") {
    Idents(sc_dataset_filtered) <- 'CellType'
  } else if (dataset_name == "Tabula") {
    Idents(sc_dataset_filtered) <- 'cell_ontology_class'
  } else if (dataset_name == "cell2location34") {
    Idents(sc_dataset_filtered) <- 'Subset'
  }

  # Cell types with at most 25 cells are considered rare.
  type_counts <- table(Idents(sc_dataset_filtered))
  rare_cell_types <- names(type_counts[type_counts <= 25])

  # Drop the rare cell types one at a time.
  for (rare_type in rare_cell_types) {
    if (dataset_name == "cell2location") {
      sc_dataset_filtered <- subset(sc_dataset_filtered, subset = CellType != rare_type)
    } else if (dataset_name == "Tabula") {
      sc_dataset_filtered <- subset(sc_dataset_filtered, subset = cell_ontology_class != rare_type)
    } else if (dataset_name == "cell2location34") {
      sc_dataset_filtered <- subset(sc_dataset_filtered, subset = Subset != rare_type)
    }
  }

  return(sc_dataset_filtered)
}
# Drop rare cell types, then build (or load from cache) a version of the
# dataset with a capped number of cells per cell type.
sc_dataset_filtered = delete_rare_cells(sc_dataset_filtered)
sc_dataset_filtered
## An object of class Seurat 
## 10237 features across 72548 samples within 1 assay 
## Active assay: RNA (10237 features, 0 variable features)
##  2 dimensional reductions calculated: pca, umap
# Cache file name, e.g. "cell2location34_sc_dataset.downsampled.rds".
sc_dataset.downsampled_file = paste(dataset_name,"sc_dataset.downsampled.rds",sep ="_")


if (file.exists(sc_dataset.downsampled_file)){
  # Cache hit: reuse the previously saved downsampled object.
  rds_file = sc_dataset.downsampled_file
  sc_dataset_filtered.downsampled <- readRDS(rds_file)
}else
{
  
# Return a copy of `sc_dataset` restricted to the cells picked by WhichCells()
# with `downsample = 3000`, using the annotation column selected by the global
# `dataset_name` as cell identity.
subsample_cells <- function(sc_dataset)
{
  if(dataset_name == "cell2location"){
  Idents(sc_dataset) <- 'CellType'
  }else if(dataset_name == "Tabula"){
  Idents(sc_dataset) <- 'cell_ontology_class'
  }else if(dataset_name == "cell2location34"){
  Idents(sc_dataset) <- 'Subset'
  }
  
  # Bare expressions inside a function body: evaluated but not printed.
  table(Idents(sc_dataset))
  unique(Idents(sc_dataset))
  
  
  # Cell names to keep, drawn across all identities.
  cell.list <- WhichCells(sc_dataset, idents = unique(Idents(sc_dataset)), 
                          downsample = 3000)
  sc_dataset.downsampled <- sc_dataset[, cell.list]
  # table(sc_dataset.downsampled$CellType)
  return(sc_dataset.downsampled)
  
}
# subsample fixed numbers of cells from differently sized cell types in a seurat object.
sc_dataset_filtered.downsampled = subsample_cells(sc_dataset_filtered)
# Cache miss: store the result so later runs take the branch above.
saveRDS(sc_dataset_filtered.downsampled,sc_dataset.downsampled_file)


}
# Build a spacexr Reference object from a single-cell Seurat object.
#
# sc_dataset_filtered: Seurat object with raw counts in the RNA assay, an
#                      `nCount_RNA` metadata column and the cell type column
#                      selected by the global `dataset_name`.
#
# Returns the Reference object. Prints the dimensions of its count matrix
# (genes x cells) as a side effect. Stops if `dataset_name` is not supported.
Create_Reference_object <- function(sc_dataset_filtered)
{
  # Pass the counts on as stored (sparse) instead of densifying them with
  # as.matrix(): a dense genes x cells matrix costs gigabytes here, and a base
  # matrix has a length-2 class that trips the scalar class checks reported as
  # warnings for this call (errors on R >= 4.2).
  counts <- sc_dataset_filtered@assays$RNA@counts

  # Metadata column holding the cell type annotation for the chosen dataset.
  cell_type_column <- switch(
    dataset_name,
    "cell2location" = "CellType",
    "Tabula" = "cell_ontology_class",
    "cell2location34" = "Subset",
    stop("Unsupported dataset_name: ", dataset_name, call. = FALSE)
  )
  meta_data <- sc_dataset_filtered@meta.data[, c(cell_type_column, "nCount_RNA")]

  # Cell type per barcode; "/" in a cell type name is replaced by "_".
  cell_types <- str_replace_all(as.character(meta_data[[cell_type_column]]), "/", "_")
  names(cell_types) <- rownames(meta_data)
  cell_types <- droplevels(as.factor(cell_types))

  # Total UMI count per barcode.
  nUMI <- meta_data$nCount_RNA
  names(nUMI) <- rownames(meta_data)

  ### Create the Reference object
  reference <- Reference(counts, cell_types, nUMI)
  # gene number; cell number
  print(dim(reference@counts))
  return(reference)
}

reference = Create_Reference_object(sc_dataset_filtered.downsampled)
## Warning in if (class(counts) != "dgCMatrix") {: the condition has length > 1 and
## only the first element will be used
## Warning in if (class(counts) != "matrix") tryCatch({: the condition has length >
## 1 and only the first element will be used
## [1] 10237 47443
# Load (or build and cache) the spacexr SpatialRNA object ("puck") for the
# Visium sample.
spatial_dataset_name = "Visium"
puck_file = paste(spatial_dataset_name,"puck.rds",sep ="_")

if (file.exists(puck_file)){
  # Cache hit: reuse the previously saved puck.
  rds_file = puck_file
  puck <- readRDS(rds_file)
}else{
  # Build the SpatialRNA object from the saved Visium Seurat object, write it
  # to `puck_file` and return it.
  creat_spatial_RCTD <- function(){
    rds_file = "Visium_spatial.rds"
    spatial <- readRDS(rds_file)

    #https://github.com/dmcable/RCTD/issues/26
    # coords <- spatial@images$slice1@image
    # pulling unscaled tissue coordinates
    coords <- GetTissueCoordinates(spatial, scale = NULL)
    # Hand the counts over as stored (sparse) rather than through as.matrix():
    # this avoids a dense genes x spots copy and the length-2 class of a base
    # matrix, which is an error in scalar class checks on R >= 4.2.
    spatialcounts <- spatial@assays$Spatial@counts
    ### Create SpatialRNA object
    puck <- SpatialRNA(coords, spatialcounts)
    saveRDS(puck,puck_file)
    return(puck)
  }

  puck = creat_spatial_RCTD()
}
# Run RCTD in 'full' doublet mode on the puck with the single-cell reference,
# or load the result from the cache file when it already exists.
library(quadprog)

# Cache file name, e.g. "cell2location34_full_RCTD.rds".
RCTD_file = paste(dataset_name,"full_RCTD.rds",sep ="_")
# RCTD_file = paste(dataset_name,"multi_RCTD.rds",sep ="_")
# RCTD_file = paste(dataset_name,"doublet_RCTD.rds",sep ="_")

if(file.exists(RCTD_file)){
  # Cache hit: `puck` and `reference` are not used in this branch.
  rds_file = RCTD_file
  myRCTD <- readRDS(rds_file)  
}else{
  # Cache miss: create the RCTD object (up to 18 cores), run it, save it.
  myRCTD <- create.RCTD(puck, reference, max_cores = 18)
  myRCTD <- run.RCTD(myRCTD, doublet_mode = 'full')
  # myRCTD <- run.RCTD(myRCTD, doublet_mode = 'multi')
  # myRCTD2 <- run.RCTD(myRCTD, doublet_mode = 'doublet')
  saveRDS(myRCTD,RCTD_file)  
}

Including Plots

You can also embed plots, for example:

# Write the standard RCTD result plots and return the normalized weights.
#
# myRCTD: RCTD object that has been run (its `results$weights` are used).
#
# Returns the weight matrix after normalize_weights(), i.e. with the cell type
# proportions of every spot summing to 1. As a side effect the plot files are
# written into the directory 'Tabula_Visium_RCTD_Plots', which is created
# under the working directory when missing.
show_results <- function(myRCTD)
{
  results <- myRCTD@results
  # normalize the cell type proportions to sum to 1.
  norm_weights <- normalize_weights(results$weights)

  cell_type_names <- myRCTD@cell_type_info$info[[2]] # list of cell type names
  spatialRNA <- myRCTD@spatialRNA
  # NOTE(review): the directory name says "Tabula" whatever reference dataset
  # is in use; kept unchanged so existing outputs stay where they are.
  resultsdir <- 'Tabula_Visium_RCTD_Plots' ## you may change this to a more accessible directory on your computer.
  # Create the directory only when needed, so re-runs do not warn that it
  # already exists.
  if (!dir.exists(resultsdir)) {
    dir.create(resultsdir)
  }

  # Confident (thresholded) weights for each cell type, as in full mode.
  plot_weights(cell_type_names, spatialRNA, resultsdir, norm_weights)

  # All (unthresholded) weights for each cell type, as in full mode.
  plot_weights_unthreshold(cell_type_names, spatialRNA, resultsdir, norm_weights)

  # Number of pixels in which each cell type is confidently located (full mode).
  plot_cond_occur(cell_type_names, resultsdir, norm_weights, spatialRNA)

  return(norm_weights)
}
norm_weights_object = show_results(myRCTD)

# Build a spacexr SpatialRNA object from a GeoMx export workbook.
#
# Reads ROI coordinates from the "SegmentProperties" sheet and the count
# matrix from the "TargetCountMatrix" sheet of the same workbook.
# Returns the SpatialRNA object (created with require_int = FALSE).
creat_spatial_RCTD_GeoMx <- function()
{
  # Single definition of the workbook path; the stray trailing space that
  # followed ".xlsx" in the file name has been removed.
  geomx_workbook <- "D:/minor_intership_data/GeoMx WTA Human Lymph Node Data_count_results/Export4_NormalizationQ3.xlsx"

  # One row per segment: x/y coordinates, row names = segment display name.
  coords <- read_excel(geomx_workbook, sheet = "SegmentProperties")
  coords <- as.data.frame(coords[,c("SegmentDisplayName","ROICoordinateX","ROICoordinateY")])
  rownames(coords) = coords[,"SegmentDisplayName"]
  coords$SegmentDisplayName <- NULL

  # Count matrix: the first column supplies the row names, then the
  # `TargetName` column is dropped.
  counts <- read_excel(geomx_workbook, sheet = "TargetCountMatrix")
  counts <- as.data.frame(counts)
  rownames(counts) = counts[,1]
  counts$TargetName <- NULL

  puck <- SpatialRNA(coords, counts,require_int = FALSE)

  return(puck)
}
# puck = creat_spatial_RCTD_GeoMx()
# saveRDS(puck,'spatial_GeoMx.rds')
# rds_file='spatial_GeoMx.rds'
# puck <- readRDS(rds_file)

Note that the echo = FALSE parameter was added to the code chunk to prevent printing of the R code that generated the plot.

# Prepare the inputs for the spatial scatterpie plots: a thresholded weight
# matrix (`mat`), a named colour palette (`pal`, `pal_back`) and the Visium
# object restricted to the spots that have RCTD weights.
library(ggplot2)
library(SPOTlight)
normalized_weights <- normalize_weights(myRCTD@results$weights)
# normalize the cell type proportions to sum to 1.
class(normalized_weights)
## [1] "dgeMatrix"
## attr(,"package")
## [1] "Matrix"
# Convert to a base matrix; `norm_weights_matrix` keeps the unthresholded values.
norm_weights_matrix = as(normalized_weights,'matrix')
# mat <- res$mat
mat <- norm_weights_matrix
ct <- colnames(mat)
# Zero out weights below 0.1; rows of `mat` may therefore no longer sum to 1.
mat[mat < 0.1] <- 0


library(RColorBrewer)


# Number of base colours to sample, per reference dataset.
# NOTE(review): the recorded output further down lists 33 cell types for
# "cell2location34" while n is 34; n only controls how many base colours are
# sampled, the palette itself is interpolated to length(ct) below.
if(dataset_name == "Tabula"){
  n <- 29
}else if(dataset_name == "cell2location"){
  n <- 44
}else if(dataset_name == "cell2location34"){
  n <- 34
}


# Pool all colours of the Brewer palettes flagged as colour-blind friendly and
# draw n of them at random.
# NOTE(review): no seed is set in this file, so the colours presumably differ
# between runs -- confirm whether that is intended.
colrs <- brewer.pal.info[brewer.pal.info$colorblind == TRUE, ]
col_vec = unlist(mapply(brewer.pal, colrs$maxcolors, rownames(colrs)))
col <- sample(col_vec, n)
rds_file = "Visium_spatial.rds"
spatial <- readRDS(rds_file)

# Spots present in the Visium object but without a row in the weight matrix.
test = spatial@assays$Spatial@counts@Dimnames[[2]] %in% rownames(norm_weights_matrix)
which(test == FALSE)
## [1]  464 1343
spatial@assays$Spatial@counts@Dimnames[[2]][464]
## [1] "ACTGTAGCACTTTGGA-1"
spatial@assays$Spatial@counts@Dimnames[[2]][1343]
## [1] "CCCGCCATGCTCCCGT-1"
# Remove those two spots; the barcodes are hard-coded from the output above.
spatial@meta.data$spot_name  <- rownames(spatial@meta.data)
spatial = subset(spatial,spot_name !="ACTGTAGCACTTTGGA-1" )
spatial = subset(spatial,spot_name !="CCCGCCATGCTCCCGT-1")


# Define color palette
# (the variable is called 'paletteMartin', but the colours are the ones
# sampled from the RColorBrewer palettes above)
paletteMartin <- col

# Interpolate the sampled colours to one colour per cell type.
pal <- colorRampPalette(paletteMartin)(length(ct))
names(pal) <- ct


# Untouched copy of the full palette; the highlight helpers start from it.
pal_back <- pal

# Recolour a palette so that only three cell types stand out.
#
# pal: named character vector of colours (names = cell types).
#
# Returns `pal` with "T_CD4+_naive" set to yellow, "FDC" to light blue,
# "B_naive" to red and every other entry to dark blue. Each name is printed
# while it is processed.
plot_3_region <- function(pal)
{
  highlight_colours <- c(
    "T_CD4+_naive" = "#FFFF00",
    "FDC" = "#add8e6",
    "B_naive" = "#FF0000"
  )
  background_colour <- "#00008b"

  for (cell_type in names(pal)) {
    print(cell_type)
    if (cell_type %in% names(highlight_colours)) {
      pal[cell_type] <- highlight_colours[[cell_type]]
    } else {
      pal[cell_type] <- background_colour
    }
  }

  return(pal)
}
# Record the SPOTlight version that provides plotSpatialScatterpie().
packageVersion("SPOTlight")
## [1] '0.99.4'
# Three-colour highlight palette derived from the full palette; the call
# prints every cell type name (recorded output follows).
pal = plot_3_region(pal_back)
## [1] "B_activated"
## [1] "B_Cycling"
## [1] "B_GC_DZ"
## [1] "B_GC_LZ"
## [1] "B_GC_prePB"
## [1] "B_IFN"
## [1] "B_mem"
## [1] "B_naive"
## [1] "B_plasma"
## [1] "B_preGC"
## [1] "DC_CCR7+"
## [1] "DC_cDC1"
## [1] "DC_cDC2"
## [1] "DC_pDC"
## [1] "Endo"
## [1] "FDC"
## [1] "ILC"
## [1] "Macrophages_M1"
## [1] "Macrophages_M2"
## [1] "Monocytes"
## [1] "NK"
## [1] "NKT"
## [1] "T_CD4+"
## [1] "T_CD4+_naive"
## [1] "T_CD4+_TfH"
## [1] "T_CD4+_TfH_GC"
## [1] "T_CD8+_CD161+"
## [1] "T_CD8+_cytotoxic"
## [1] "T_CD8+_naive"
## [1] "T_TfR"
## [1] "T_TIM3+"
## [1] "T_Treg"
## [1] "VSMC"
# Scatterpie of the thresholded RCTD weights with the three-region palette.
plotSpatialScatterpie(
    x = spatial,
    y = mat,
  # `y` is only the callee's argument name and is not defined by this script,
  # so the cell types are taken from the matrix that is actually passed.
  cell_types = colnames(mat),
  img = FALSE,
  scatterpie_alpha = 1,
  pie_scale = 0.4) +
  scale_fill_manual(
    values = pal,
    breaks = names(pal))

# Recolour a palette so that only the selected cell types keep their colour.
#
# pal:              named character vector of colours (names = cell types).
# col_vec, n:       accepted for call compatibility; not used.
# cell_type_vector: names of the cell types whose colour is kept.
#
# Returns `pal` with every entry not listed in `cell_type_vector` set to dark
# blue.
plot_n_cell_type <- function(pal,col_vec,n, cell_type_vector)
{
  background_colour <- "#00008b"

  for (cell_type in names(pal)) {
    if (!(cell_type %in% cell_type_vector)) {
      pal[cell_type] <- background_colour
    }
  }

  return(pal)
}

# Cell types highlighted in the germinal centre scatterpie.
cell_type_vector = c("B_GC_DZ", "B_GC_LZ", "B_GC_prePB","T_CD4+_TfH_GC",
                 "B_Cycling",  "FDC")

# Palette in which only the cell types above keep their colour.
pal = plot_n_cell_type(pal_back,col_vec,6,cell_type_vector)
# Kept in `p1` because it is printed again later, next to the manual annotation.
p1 = plotSpatialScatterpie(
    x = spatial,
    y = mat,
  # `y` is only the callee's argument name and is not defined by this script,
  # so the cell types are taken from the matrix that is actually passed.
  cell_types = colnames(mat),
  img = FALSE,
  scatterpie_alpha = 1,
  pie_scale = 0.4) +
  scale_fill_manual(
    values = pal,
    breaks = names(pal))

# Same plot with the full palette, one colour per cell type.
pal = pal_back
# type(pal)
plotSpatialScatterpie(
    x = spatial,
    y = mat,
  cell_types = colnames(mat),
  img = FALSE,
  scatterpie_alpha = 1,
  pie_scale = 0.4) +
  scale_fill_manual(
    values = pal,
    breaks = names(pal))

# Store the RCTD weights in the Visium object as a "predictions" assay so that
# they can be plotted with SpatialFeaturePlot().
# A previously saved assay object is loaded and used as a container; its
# `data` and `meta.features` slots are overwritten below.
rds_file='predictions.assay.rds'
predictions.assay <- readRDS(rds_file)

results <- myRCTD@results
# normalize the cell type proportions to sum to 1.
norm_weights = normalize_weights(results$weights) 

# The weight matrix is transposed so that cell types become the rows (features).
predictions.assay@data <- t(as.matrix(norm_weights))       
# Build a data frame with one row per cell type and no columns: the cell type
# names become the row names, then the helper column is removed again.
meta.features <- as.data.frame(colnames(as.matrix(norm_weights)) )
rownames(meta.features) = meta.features[,1]
meta.features$`colnames(as.matrix(norm_weights))` =NULL
predictions.assay@meta.features = meta.features

spatial[["predictions"]] <- predictions.assay

# Make the cell type weights the default features for the plots below.
DefaultAssay(spatial) <- "predictions"
# Cell types available as features (recorded output follows).
colnames(mat)
##  [1] "B_activated"      "B_Cycling"        "B_GC_DZ"          "B_GC_LZ"         
##  [5] "B_GC_prePB"       "B_IFN"            "B_mem"            "B_naive"         
##  [9] "B_plasma"         "B_preGC"          "DC_CCR7+"         "DC_cDC1"         
## [13] "DC_cDC2"          "DC_pDC"           "Endo"             "FDC"             
## [17] "ILC"              "Macrophages_M1"   "Macrophages_M2"   "Monocytes"       
## [21] "NK"               "NKT"              "T_CD4+"           "T_CD4+_naive"    
## [25] "T_CD4+_TfH"       "T_CD4+_TfH_GC"    "T_CD8+_CD161+"    "T_CD8+_cytotoxic"
## [29] "T_CD8+_naive"     "T_TfR"            "T_TIM3+"          "T_Treg"          
## [33] "VSMC"
# Spatial feature plots of selected cell type weights, per reference dataset.
# Inside a braced block only the value of the last expression is auto-printed,
# so every earlier plot must be printed explicitly.
if(dataset_name == "Tabula"){
  
p = SpatialFeaturePlot(spatial, features = c("stromal cell", "t cell"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)
print(p)
p = SpatialFeaturePlot(spatial, features = c("naive b cell", "b cell"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)
print(p)
  
}else if(dataset_name == "cell2location"){
  
# This first plot used to be dropped silently: it is not the last expression
# of the block, so it needs an explicit print().
print(SpatialFeaturePlot(spatial, features = c("T_Treg", "DC_cDC2"), pt.size.factor = 1.6, ncol = 2, crop = TRUE))
SpatialFeaturePlot(spatial, features = c("DC_cDC1", "DC_cDC2"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)

}else if(dataset_name == "cell2location34"){
  
# detach("package:SpatialExperiment", unload=TRUE)
  
p = SpatialFeaturePlot(spatial, features = c("B_GC_LZ", "T_CD4+_TfH_GC","B_GC_prePB","FDC"), pt.size.factor = 1.6, ncol = 4, crop = TRUE) + ggtitle("germinal center light zone")
print(p)
 # + scale_fill_continuous(limits = c(0, 1))

p = SpatialFeaturePlot(spatial, features = c("B_Cycling", "B_GC_DZ"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)+ ggtitle("germinal center dark zone") 
print(p)
SpatialFeaturePlot(spatial, features = c("B_naive", "B_preGC"), pt.size.factor = 1.6, ncol = 2, crop = TRUE) + ggtitle("B follicle + pre GC") 
}

# table(sc_dataset@meta.data$Subset)
# Leave out selected cell types (B_GC_DZ, B_GC_LZ, B_GC_prePB, T_CD4+_TfH_GC, B_Cycling and FDC, these six subtypes are expected to be present in annotated GC (positive locations) one by one using ‘subset’ command.
# for (selected_type in c("B_GC_DZ", "B_GC_LZ", "B_GC_prePB","T_CD4+_TfH_GC",
#                  "B_Cycling",  "FDC") )
# {
#   # selected_type = "B_GC_DZ"
#   Idents(sc_dataset_filtered.downsampled) <- 'Subset'
#   sc_reference <- subset(sc_dataset_filtered.downsampled, subset = Subset != selected_type)
#   table(sc_reference@meta.data$Subset)
#   reference = Create_Reference_object(sc_reference)
#   
#   RCTD_file = paste(dataset_name,"full_RCTD", selected_type, ".rds",sep ="_")
#   library(quadprog)
#   myRCTD <- create.RCTD(puck, reference, max_cores = 18)
#   myRCTD <- run.RCTD(myRCTD, doublet_mode = 'full')
#   saveRDS(myRCTD,RCTD_file)
#   print(RCTD_file)
# }
# getwd()
# Overlay the manual germinal centre (GC) annotation on the Visium spots and
# show it next to the GC scatterpie (`p1`) built earlier.
library(ggplot2)
library(SPOTlight)

manual_GC_annotation = read.csv("/data/home/lyang/scripts_git/manual_GC_annotation.csv")
rds_file = "Visium_spatial.rds"
spatial <- readRDS(rds_file)
spatial@meta.data$Barcode = rownames(spatial@meta.data)

# merge() discards row names and does not guarantee the row order, while the
# metadata of the object is keyed by barcode row names. Keep every spot
# (all.x), put the rows back into the original spot order and restore the
# barcodes as row names.
spot_barcodes <- rownames(spatial@meta.data)
annotated_meta <- merge(spatial@meta.data, manual_GC_annotation,
                        by = "Barcode", all.x = TRUE, sort = FALSE)
annotated_meta <- annotated_meta[match(spot_barcodes, annotated_meta$Barcode), ,
                                 drop = FALSE]
rownames(annotated_meta) <- spot_barcodes
spatial@meta.data = annotated_meta

# Spots without a label (empty string, or absent from the annotation file)
# are marked "no".
unlabelled <- is.na(spatial@meta.data$GC) | spatial@meta.data$GC == ''
spatial@meta.data[unlabelled, "GC"] = "no"

p2 <- SpatialDimPlot(spatial, group.by = "GC") + scale_y_reverse()
# p1 + p2
print(p2)

print(p1)

# For each germinal centre cell type that was left out of the reference, load
# the corresponding RCTD result and draw the GC scatterpie.
# Everything that does not depend on the left-out cell type is prepared once,
# before the loop, instead of being recomputed in every iteration.
library(RColorBrewer)

# Visium object restricted to the spots that have RCTD weights (the two
# barcodes below are the spots without weights).
rds_file = "Visium_spatial.rds"
spatial <- readRDS(rds_file)
spatial@meta.data$spot_name  <- rownames(spatial@meta.data)
spatial = subset(spatial,spot_name !="ACTGTAGCACTTTGGA-1" )
spatial = subset(spatial,spot_name !="CCCGCCATGCTCCCGT-1")

# Colour pool: all colours of the colour-blind friendly Brewer palettes.
n <- 34
colrs <- brewer.pal.info[brewer.pal.info$colorblind == TRUE, ]
col_vec = unlist(mapply(brewer.pal, colrs$maxcolors, rownames(colrs)))

# Recolour `pal` so that only the cell types in `cell_type_vector` keep their
# colour; `col_vec` and `n` are accepted for call compatibility and not used.
plot_n_cell_type <- function(pal,col_vec,n, cell_type_vector)
{
  for (char in names(pal)) {
    if(char %in%  cell_type_vector){
      next
    }
    pal[char] = "#00008b"
  }

  return(pal)
}

cell_type_vector = c("B_GC_DZ", "B_GC_LZ", "B_GC_prePB","T_CD4+_TfH_GC",
                 "B_Cycling",  "FDC")

for (selected_type in cell_type_vector)
{
  # Result of the RCTD run whose reference lacked `selected_type`; the file
  # name pattern matches the one used when these results were saved.
  RCTD_file = paste(dataset_name,"full_RCTD", selected_type, ".rds",sep ="_")
  myRCTD_deleted <- readRDS(RCTD_file)
  # normalize the cell type proportions to sum to 1.
  normalized_weights_deleted <- normalize_weights(myRCTD_deleted@results$weights)
  matrix_result_deleted = as(normalized_weights_deleted,'matrix')
  mat <- matrix_result_deleted
  ct <- colnames(mat)
  # Zero out weights below 0.1.
  mat[mat < 0.1] <- 0

  # Fresh random palette for this result, interpolated to one colour per cell
  # type present in it.
  col <- sample(col_vec, n)
  pal <- colorRampPalette(col)(length(ct))
  names(pal) <- ct
  pal_back <- pal

  pal = plot_n_cell_type(pal_back,col_vec,6,cell_type_vector)
  p1 = plotSpatialScatterpie(
      x = spatial,
      y = mat,
    # `y` is only the callee's argument name and is not defined by this
    # script, so the cell types are taken from the matrix actually passed.
    cell_types = colnames(mat),
    img = FALSE,
    scatterpie_alpha = 1,
    pie_scale = 0.4) +
    scale_fill_manual(
      values = pal,
      breaks = names(pal))
  print(p1)
}

# Quantify how much the predicted proportions change when one germinal centre
# cell type is left out of the reference: per spot, the root-mean-square
# difference between the full-reference weights and the reduced-reference
# weights, summed over all spots.
deleted_cell_types <- c("B_GC_DZ", "B_GC_LZ", "B_GC_prePB","T_CD4+_TfH_GC",
                        "B_Cycling",  "FDC")
rmsd_summary = rep(0, length(deleted_cell_types))
names(rmsd_summary) = deleted_cell_types

# Arrange a weight matrix on a common (spots x types) grid. Cell types that
# are absent from `weights` (such as the left-out type) get a weight of 0.
pad_weights <- function(weights, spots, types)
{
  padded <- matrix(0, nrow = length(spots), ncol = length(types),
                   dimnames = list(spots, types))
  present <- intersect(types, colnames(weights))
  padded[, present] <- weights[spots, present, drop = FALSE]
  padded
}

for (selected_type in deleted_cell_types)
{
  matrix_result_complete = norm_weights_matrix
  RCTD_file = paste(dataset_name,"full_RCTD", selected_type, ".rds",sep ="_")
  myRCTD_deleted <- readRDS(RCTD_file)
  # normalize the cell type proportions to sum to 1.
  normalized_weights_deleted <- normalize_weights(myRCTD_deleted@results$weights)
  matrix_result_deleted = as(normalized_weights_deleted,'matrix')

  # Compare the two results by spot and cell type NAME rather than by
  # position, so a different row or column order cannot pair up the wrong
  # values.
  shared_spots <- intersect(rownames(matrix_result_complete),
                            rownames(matrix_result_deleted))
  if (length(shared_spots) < nrow(matrix_result_complete)) {
    warning("Only ", length(shared_spots), " of ", nrow(matrix_result_complete),
            " spots are present in ", RCTD_file, call. = FALSE)
  }
  all_types <- union(colnames(matrix_result_complete),
                     colnames(matrix_result_deleted))

  # for each spatial spot, Calculate the root-mean-square error (RMSE) distance between the predicted proportion under different reference dataset situations
  # load("/data/home/lyang/Visium_RCTD/.RData")
  matrix_substrated = pad_weights(matrix_result_complete, shared_spots, all_types) -
    pad_weights(matrix_result_deleted, shared_spots, all_types)
  rmsd_per_spot = sqrt(rowSums(matrix_substrated^2)/ncol(matrix_substrated))
  rmsd_per_situation = sum(rmsd_per_spot)
  rmsd_summary[selected_type] = rmsd_per_situation

}


# One bar per left-out cell type.
df <- data.frame(deleted_cell_type=names(rmsd_summary),
                RMSD=rmsd_summary  )
p<-ggplot(data=df, aes(x=deleted_cell_type, y=RMSD)) +
  geom_bar(stat="identity")
p